﻿using DotNetCommon.Extensions;
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace DBUtil
{
    public abstract partial class DBAccess
    {
        // Session state of the current async flow; null when no session is active.
        internal AsyncLocal<SessionContext> SessionContext = new();

        /// <summary>
        /// Gets a new <see cref="DbConnection"/>, equivalent to <c>return new SqlConnection(db.DBConn)</c>.
        /// </summary>
        /// <remarks>Note: opening, disposing and transactions of the returned connection are entirely the caller's responsibility.</remarks>
        public DbConnection GetNewConnection() => GetConnectionByConnectString();

        /// <summary>
        /// Provider-specific factory used by <see cref="GetNewConnection"/> to build a connection from the connection string.
        /// </summary>
        protected abstract DbConnection GetConnectionByConnectString();

        /// <summary>
        /// Whether the current session has an active transaction.
        /// </summary>
        public bool IsTransaction => SessionContext.Value?.Transaction != null;

        /// <summary>
        /// Whether the current async flow is running inside a session (with or without a transaction).
        /// </summary>
        public bool IsSession => SessionContext.Value != null;

        /// <summary>
        /// The connection of the current session when <see cref="IsSession"/> or <see cref="IsTransaction"/> is true; otherwise null.
        /// </summary>
        protected DbConnection CurrentConnection => SessionContext.Value?.Connection;

        /// <summary>
        /// The active <see cref="DbTransaction"/> when <see cref="IsTransaction"/> is true; otherwise null. Use with care.
        /// </summary>
        protected DbTransaction CurrentTransaction => SessionContext.Value?.Transaction;

        /// <summary>
        /// Isolation level of the active transaction, or null when there is none.
        /// </summary>
        public IsolationLevel? TransactionLevel => SessionContext.Value?.Transaction?.IsolationLevel;

        #region RunInSessionAsync
        /// <summary>
        /// Runs <paramref name="func"/> inside a session (one shared connection for this async flow).
        /// The current session is joined unless there is none or <paramref name="forseNew"/> is true.
        /// </summary>
        /// <param name="func">Work to run inside the session.</param>
        /// <param name="forseNew">Start a separate session even if one is already active.</param>
        public async Task RunInSessionAsync(Func<Task> func, bool forseNew = false)
        {
            // Adapt the value-less delegate to the generic overload, which does the real work.
            async Task<int> Adapted()
            {
                await func();
                return 1;
            }

            await RunInSessionAsync<int>(Adapted, forseNew);
        }

        /// <summary>
        /// Runs <paramref name="func"/> inside a session: commands issued through this instance while it runs share one connection.
        /// </summary>
        /// <remarks>
        /// The current session is joined when one exists and <paramref name="forseNew"/> is false. Otherwise a new connection
        /// is created for this call and disposed when it ends. The connection is physically closed once the last user of the
        /// session has finished.
        /// </remarks>
        /// <param name="func">Work to run inside the session.</param>
        /// <param name="forseNew">Start a separate session even if one is already active.</param>
        /// <returns>The result of <paramref name="func"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
        public async Task<T> RunInSessionAsync<T>(Func<Task<T>> func, bool forseNew = false)
        {
            ArgumentNullException.ThrowIfNull(func, nameof(func));
            var ctx = SessionContext.Value;
            var ownsSession = ctx == null || forseNew;
            if (ownsSession)
            {
                ctx = new SessionContext() { Connection = GetNewConnection() };
                // An AsyncLocal assigned inside an async method does not flow back to the caller,
                // so the new session is only visible to func and what it awaits.
                SessionContext.Value = ctx;
            }
            try
            {
                // Use the async open/close so an async caller never blocks a thread on connection I/O.
                await ctx.OpenAsync();
                try { return await func(); }
                finally { await ctx.CloseAsync(); }
            }
            finally
            {
                // Only the call that created the connection disposes it; joined sessions leave that to the owner.
                if (ownsSession) await ctx.Connection.DisposeAsync();
            }
        }

        /// <summary>
        /// Runs <paramref name="func"/> inside a session (one shared connection).
        /// The current session is joined unless there is none or <paramref name="forseNew"/> is true.
        /// </summary>
        /// <param name="func">Work to run inside the session.</param>
        /// <param name="forseNew">Start a separate session even if one is already active.</param>
        public void RunInSession(Action func, bool forseNew = false)
        {
            // Adapt the Action to the generic overload, which does the real work.
            int Adapted()
            {
                func();
                return 1;
            }

            RunInSession<int>(Adapted, forseNew);
        }

        /// <summary>
        /// Runs <paramref name="func"/> inside a session: commands issued through this instance while it runs share one connection.
        /// </summary>
        /// <remarks>
        /// The current session is joined when one exists and <paramref name="forseNew"/> is false. Otherwise a new connection
        /// is created for this call and disposed when it ends. <paramref name="func"/> runs on a thread-pool thread
        /// (via <see cref="Task.Run{TResult}(Func{TResult})"/>), so the session assigned there does not leak into the caller's context.
        /// Exceptions thrown by <paramref name="func"/> propagate unwrapped.
        /// </remarks>
        /// <param name="func">Work to run inside the session.</param>
        /// <param name="forseNew">Start a separate session even if one is already active.</param>
        /// <returns>The result of <paramref name="func"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
        public T RunInSession<T>(Func<T> func, bool forseNew = false)
        {
            ArgumentNullException.ThrowIfNull(func, nameof(func));
            // GetAwaiter().GetResult() rethrows func's own exception. .Result would wrap it in an AggregateException,
            // and would nest one AggregateException per level when transactions are nested.
            return Task.Run(() =>
            {
                var ctx = SessionContext.Value;
                var ownsSession = ctx == null || forseNew;
                if (ownsSession)
                {
                    ctx = new SessionContext() { Connection = GetNewConnection() };
                    SessionContext.Value = ctx;
                }
                try
                {
                    ctx.Open();
                    try { return func(); }
                    finally { ctx.Close(); }
                }
                finally
                {
                    // Only the call that created the connection disposes it, even if opening it failed.
                    if (ownsSession) ctx.Connection.Dispose();
                }
            }).GetAwaiter().GetResult();
        }
        #endregion

        #region RunInNoSession
        /// <summary>
        /// Temporarily leaves the current session and runs <paramref name="func"/> without one.
        /// </summary>
        /// <param name="func">Work to run outside any session.</param>
        public async Task RunInNoSessionAsync(Func<Task> func)
        {
            // Adapt the value-less delegate to the generic overload, which does the real work.
            async Task<int> Adapted()
            {
                await func();
                return 1;
            }

            await RunInNoSessionAsync<int>(Adapted);
        }

        /// <summary>
        /// Temporarily leaves the current session and runs <paramref name="func"/> without one.
        /// </summary>
        /// <remarks>
        /// The AsyncLocal is cleared inside this async method. That change does not flow back to the caller, so the
        /// caller's session is still in place after this method returns.
        /// </remarks>
        /// <param name="func">Work to run outside any session.</param>
        /// <returns>The result of <paramref name="func"/>.</returns>
        public async Task<T> RunInNoSessionAsync<T>(Func<Task<T>> func)
        {
            SessionContext.Value = default;
            T result = await func();
            return result;
        }

        /// <summary>
        /// Temporarily leaves the current session and runs <paramref name="func"/> without one.
        /// </summary>
        /// <param name="func">Work to run outside any session.</param>
        public void RunInNoSession(Action func)
        {
            // Adapt the Action to the generic overload, which does the real work.
            int Adapted()
            {
                func();
                return 1;
            }

            RunInNoSession<int>(Adapted);
        }

        /// <summary>
        /// Temporarily leaves the current session and runs <paramref name="func"/> without one.
        /// </summary>
        /// <remarks>
        /// <paramref name="func"/> runs on a thread-pool thread (via <see cref="Task.Run{TResult}(Func{TResult})"/>), so clearing
        /// the session there does not affect the caller. Exceptions thrown by <paramref name="func"/> propagate unwrapped.
        /// </remarks>
        /// <param name="func">Work to run outside any session.</param>
        /// <param name="forseNew">Has no effect (the session is always hidden from <paramref name="func"/>); kept for source compatibility.</param>
        /// <returns>The result of <paramref name="func"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
        public T RunInNoSession<T>(Func<T> func, bool forseNew = false)
        {
            ArgumentNullException.ThrowIfNull(func, nameof(func));
            // GetAwaiter().GetResult() rethrows func's own exception instead of wrapping it in an AggregateException (.Result).
            return Task.Run(() =>
            {
                SessionContext.Value = null;
                return func();
            }).GetAwaiter().GetResult();
        }
        #endregion

        #region RunInTransactionAsync
        /// <summary>
        /// Runs <paramref name="func"/> in a transaction of the current (or a new) session.
        /// </summary>
        /// <param name="func">Work to run inside the transaction.</param>
        /// <param name="isolationLevel">Requested isolation level; null lets the generic overload pick its default.</param>
        public async Task RunInTransactionAsync(Func<Task> func, IsolationLevel? isolationLevel = null)
        {
            // Adapt the value-less delegate to the generic overload, which does the real work.
            async Task<int> Adapted()
            {
                await func();
                return 1;
            }

            await RunInTransactionAsync<int>(Adapted, isolationLevel);
        }

        /// <summary>
        /// Runs <paramref name="func"/> in a transaction of the current (or a new) session.
        /// </summary>
        /// <remarks>
        /// The transaction is rolled back when <paramref name="func"/> throws and committed when it returns normally.
        /// A nested call on the same session is implemented by the session as a savepoint of the outer transaction.
        /// A failing commit is reported as-is and is not followed by a rollback.
        /// </remarks>
        /// <param name="func">Work to run inside the transaction.</param>
        /// <param name="isolationLevel">Isolation level for a new transaction (ReadCommitted when null). A nested call may not request a higher level than the outer transaction.</param>
        /// <returns>The result of <paramref name="func"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
        public async Task<T> RunInTransactionAsync<T>(Func<Task<T>> func, IsolationLevel? isolationLevel = null)
        {
            ArgumentNullException.ThrowIfNull(func, nameof(func));
            return await RunInSessionAsync(async () =>
            {
                var ctx = SessionContext.Value;
                await ctx.OpenTransactionAsync(isolationLevel);
                T ret;
                try
                {
                    ret = await func();
                }
                catch
                {
                    await ctx.RollbackTransactionAsync();
                    throw;
                }
                // Commit outside the try. A failing commit (or post-commit callback) has already closed this transaction
                // level, so a rollback afterwards would act on the parent level or fail and mask the original error.
                await ctx.CommitTransactionAsync();
                return ret;
            });
        }

        /// <summary>
        /// Runs <paramref name="func"/> in a transaction of the current (or a new) session.
        /// </summary>
        /// <param name="func">Work to run inside the transaction.</param>
        /// <param name="isolationLevel">Requested isolation level; null lets the generic overload pick its default.</param>
        public void RunInTransaction(Action func, IsolationLevel? isolationLevel = null)
        {
            // Adapt the Action to the generic overload, which does the real work.
            int Adapted()
            {
                func();
                return 1;
            }

            RunInTransaction<int>(Adapted, isolationLevel);
        }

        /// <summary>
        /// Runs <paramref name="func"/> in a transaction of the current (or a new) session.
        /// </summary>
        /// <remarks>
        /// The transaction is rolled back when <paramref name="func"/> throws and committed when it returns normally.
        /// A nested call on the same session is implemented by the session as a savepoint of the outer transaction.
        /// A failing commit is reported as-is and is not followed by a rollback.
        /// </remarks>
        /// <param name="func">Work to run inside the transaction.</param>
        /// <param name="isolationLevel">Isolation level for a new transaction (ReadCommitted when null). A nested call may not request a higher level than the outer transaction.</param>
        /// <returns>The result of <paramref name="func"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
        public T RunInTransaction<T>(Func<T> func, IsolationLevel? isolationLevel = null)
        {
            ArgumentNullException.ThrowIfNull(func, nameof(func));
            return RunInSession(() =>
            {
                var ctx = SessionContext.Value;
                ctx.OpenTransaction(isolationLevel);
                T ret;
                try
                {
                    ret = func();
                }
                catch
                {
                    ctx.RollbackTransaction();
                    throw;
                }
                // Commit outside the try. A failing commit (or post-commit callback) has already closed this transaction
                // level, so a rollback afterwards would act on the parent level or fail and mask the original error.
                ctx.CommitTransaction();
                return ret;
            });
        }
        #endregion

        #region private RunInCommandAsync
        /// <summary>
        /// Creates a <see cref="DbCommand"/> for <paramref name="sql"/>, runs <paramref name="func"/> with it, then releases the command.
        /// </summary>
        /// <remarks>
        /// Inside a session, the session's connection (and its transaction, if any) is used and left open afterwards.
        /// Outside a session, a new connection is created and opened for this call only, then disposed when the call ends.
        /// </remarks>
        /// <param name="func">Work to perform with the prepared command.</param>
        /// <param name="sql">Command text.</param>
        /// <param name="commandType">Command type; the provider default is kept when null.</param>
        /// <param name="commandTimeout">Timeout in seconds; the provider default is kept when null.</param>
        /// <param name="parameters">Parameters to add; they are detached from the command again before it is disposed.</param>
        /// <returns>The result of <paramref name="func"/>.</returns>
        private async Task<T> RunInCommandAsync<T>(Func<DbCommand, Task<T>> func, string sql, CommandType? commandType = null, int? commandTimeout = null, IEnumerable<DbParameter> parameters = null)
        {
            var ctx = SessionContext.Value;
            if (ctx == null)
            {
                // No session: this call owns the connection.
                DbConnection conn = null;
                DbCommand cmd = null;
                try
                {
                    conn = GetNewConnection();
                    cmd = conn.CreateCommand();
                    SetupCommandCore(cmd, sql, null, commandType, commandTimeout, parameters);
                    try
                    {
                        if (conn.State != ConnectionState.Open) await conn.OpenAsync();
                    }
                    catch (Exception ex) when (ex.Message?.Contains("SSL Authentication Error") == true)
                    {
                        throw new Exception($"db链接报错: {ex.Message}, 也许可以尝试在链接字符串中加入: [SslMode=none;AllowPublicKeyRetrieval=True;]", ex);
                    }
                    return await func(cmd);
                }
                finally
                {
                    if (cmd != null) await ReleaseCommandCoreAsync(cmd);
                    // Dispose (which also closes) instead of only closing; skip when creating the connection failed.
                    if (conn != null)
                    {
                        try { await conn.DisposeAsync(); }
                        catch (Exception ex) { logger.LogError(ex, $"释放 DbConnection 时报错: {ex.Message}"); }
                    }
                }
            }
            else
            {
                // Session: reuse its connection/transaction; only the command belongs to this call.
                var cmd = ctx.Connection.CreateCommand();
                try
                {
                    SetupCommandCore(cmd, sql, ctx.Transaction, commandType, commandTimeout, parameters);
                    return await func(cmd);
                }
                finally
                {
                    await ReleaseCommandCoreAsync(cmd);
                }
            }
        }

        /// <summary>
        /// Value-less variant of <see cref="RunInCommandAsync{T}"/>.
        /// </summary>
        private async Task RunInCommandAsync(Func<DbCommand, Task> func, string sql, CommandType? commandType = null, int? commandTimeout = null, IEnumerable<DbParameter> parameters = null)
        {
            await RunInCommandAsync(async cmd =>
            {
                await func(cmd);
                return 1;
            }, sql, commandType, commandTimeout, parameters);
        }

        /// <summary>
        /// Synchronous counterpart of <see cref="RunInCommandAsync{T}"/>: creates a command for <paramref name="sql"/>,
        /// runs <paramref name="func"/> with it, then releases the command.
        /// </summary>
        /// <remarks>
        /// Inside a session, the session's connection (and its transaction, if any) is used and left open afterwards.
        /// Outside a session, a new connection is created and opened for this call only, then disposed when the call ends.
        /// </remarks>
        private T RunInCommand<T>(Func<DbCommand, T> func, string sql, CommandType? commandType = null, int? commandTimeout = null, IEnumerable<DbParameter> parameters = null)
        {
            var ctx = SessionContext.Value;
            if (ctx == null)
            {
                // No session: this call owns the connection.
                DbConnection conn = null;
                DbCommand cmd = null;
                try
                {
                    conn = GetNewConnection();
                    cmd = conn.CreateCommand();
                    SetupCommandCore(cmd, sql, null, commandType, commandTimeout, parameters);
                    try
                    {
                        if (conn.State != ConnectionState.Open) conn.Open();
                    }
                    catch (Exception ex) when (ex.Message?.Contains("SSL Authentication Error") == true)
                    {
                        throw new Exception($"db链接报错: {ex.Message}, 也许可以尝试在链接字符串中加入: [SslMode=none;AllowPublicKeyRetrieval=True;]", ex);
                    }
                    return func(cmd);
                }
                finally
                {
                    if (cmd != null) ReleaseCommandCore(cmd);
                    // Dispose (which also closes) instead of only closing; skip when creating the connection failed.
                    if (conn != null)
                    {
                        try { conn.Dispose(); }
                        catch (Exception ex) { logger.LogError(ex, $"释放 DbConnection 时报错: {ex.Message}"); }
                    }
                }
            }
            else
            {
                // Session: reuse its connection/transaction; only the command belongs to this call.
                var cmd = ctx.Connection.CreateCommand();
                try
                {
                    SetupCommandCore(cmd, sql, ctx.Transaction, commandType, commandTimeout, parameters);
                    return func(cmd);
                }
                finally
                {
                    ReleaseCommandCore(cmd);
                }
            }
        }

        /// <summary>
        /// Value-less variant of <see cref="RunInCommand{T}"/>.
        /// </summary>
        private void RunInCommand(Action<DbCommand> func, string sql, CommandType? commandType = null, int? commandTimeout = null, IEnumerable<DbParameter> parameters = null)
        {
            RunInCommand(cmd =>
            {
                func(cmd);
                return 1;
            }, sql, commandType, commandTimeout, parameters);
        }

        /// <summary>
        /// Applies the RunInCommand* arguments to <paramref name="cmd"/>. Null optional values keep the provider defaults.
        /// </summary>
        private static void SetupCommandCore(DbCommand cmd, string sql, DbTransaction transaction, CommandType? commandType, int? commandTimeout, IEnumerable<DbParameter> parameters)
        {
            cmd.CommandText = sql;
            if (transaction != null) cmd.Transaction = transaction;
            if (commandTimeout != null) cmd.CommandTimeout = commandTimeout.Value;
            if (commandType != null) cmd.CommandType = commandType.Value;
            // Materialize once so a lazy sequence is not enumerated twice.
            var parameterArray = parameters?.ToArray();
            if (parameterArray is { Length: > 0 }) cmd.Parameters.AddRange(parameterArray);
        }

        /// <summary>
        /// Best-effort release of a command created by RunInCommand*: detaches its parameters and transaction, then disposes it.
        /// Failures are logged, not thrown.
        /// </summary>
        private void ReleaseCommandCore(DbCommand cmd)
        {
            try
            {
                // Presumably cleared so the caller's DbParameter objects can be attached to another command later — confirm.
                cmd.Parameters.Clear();
                cmd.Transaction = null;
                cmd.CommandText = null;
                cmd.Dispose();
            }
            catch (Exception ex) { logger.LogError(ex, $"释放 DbCommand 时报错: {ex.Message}"); }
        }

        /// <summary>
        /// Async variant of <see cref="ReleaseCommandCore"/>. Failures are logged, not thrown.
        /// </summary>
        private async Task ReleaseCommandCoreAsync(DbCommand cmd)
        {
            try
            {
                cmd.Parameters.Clear();
                cmd.Transaction = null;
                cmd.CommandText = null;
                await cmd.DisposeAsync();
            }
            catch (Exception ex) { logger.LogError(ex, $"释放 DbCommand 时报错: {ex.Message}"); }
        }
        #endregion

        #region RunWhenCommit
        /// <summary>
        /// Runs <paramref name="func"/> once the session's transaction has committed; with no active transaction it runs immediately.
        /// </summary>
        /// <param name="func">Callback to run.</param>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
        public async Task RunWhenCommitAsync(Func<Task> func)
        {
            ArgumentNullException.ThrowIfNull(func, nameof(func));
            if (!IsTransaction)
            {
                await func.Invoke();
                return;
            }
            await SessionContext.Value.RunWhenCommitAsync(func);
        }

        /// <summary>
        /// Runs <paramref name="func"/> once the session's transaction has committed; with no active transaction it runs immediately.
        /// </summary>
        /// <param name="func">Callback to run.</param>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
        public void RunWhenCommit(Action func)
        {
            ArgumentNullException.ThrowIfNull(func, nameof(func));
            if (!IsTransaction)
            {
                func.Invoke();
                return;
            }
            SessionContext.Value.RunWhenCommit(func);
        }
        #endregion

        #region 事务保存点
        /// <summary>
        /// Whether savepoints can be used right now: a transaction is active and its
        /// <see cref="DbTransaction.SupportsSavepoints"/> is true.
        /// </summary>
        public virtual bool SupportSavePoint => IsTransaction && CurrentTransaction.SupportsSavepoints;

        /// <summary>
        /// Creates a savepoint named <paramref name="savepointName"/> directly on the current transaction.
        /// </summary>
        /// <exception cref="InvalidOperationException">Not in a transaction, or savepoints are not supported.</exception>
        public void SaveTransactionPoint(string savepointName)
        {
            EnsureSavePointSupported();
            CurrentTransaction.Save(savepointName);
        }

        /// <summary>
        /// Rolls the current transaction back to the savepoint <paramref name="savepointName"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">Not in a transaction, or savepoints are not supported.</exception>
        public void RollbackTransactionPoint(string savepointName)
        {
            EnsureSavePointSupported();
            CurrentTransaction.Rollback(savepointName);
        }

        /// <summary>
        /// Releases the savepoint <paramref name="savepointName"/> of the current transaction.
        /// </summary>
        /// <exception cref="InvalidOperationException">Not in a transaction, or savepoints are not supported.</exception>
        public void ReleaseTransactionPoint(string savepointName)
        {
            EnsureSavePointSupported();
            CurrentTransaction.Release(savepointName);
        }

        /// <summary>
        /// Async version of <see cref="SaveTransactionPoint"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">Not in a transaction, or savepoints are not supported.</exception>
        public async Task SaveTransactionPointAsync(string savepointName)
        {
            EnsureSavePointSupported();
            await CurrentTransaction.SaveAsync(savepointName);
        }

        /// <summary>
        /// Async version of <see cref="RollbackTransactionPoint"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">Not in a transaction, or savepoints are not supported.</exception>
        public async Task RollbackTransactionPointAsync(string savepointName)
        {
            EnsureSavePointSupported();
            await CurrentTransaction.RollbackAsync(savepointName);
        }

        /// <summary>
        /// Async version of <see cref="ReleaseTransactionPoint"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">Not in a transaction, or savepoints are not supported.</exception>
        public async Task ReleaseTransactionPointAsync(string savepointName)
        {
            EnsureSavePointSupported();
            await CurrentTransaction.ReleaseAsync(savepointName);
        }

        // Shared guard for the savepoint API. InvalidOperationException still derives from Exception, so existing catch blocks keep working.
        private void EnsureSavePointSupported()
        {
            if (!SupportSavePoint) throw new InvalidOperationException("当前不在事务中或数据库不支持事务保存点!");
        }
        #endregion
    }

    /// <summary>
    /// Database session state of one async flow. It holds a connection whose open/close calls are reference counted,
    /// the active transaction (nested levels are implemented as savepoints), and callbacks to run after the outermost commit.
    /// </summary>
    internal class SessionContext
    {
        /// <summary>The active transaction, or null when none is open.</summary>
        public DbTransaction Transaction { get; set; }

        /// <summary>Savepoint names of the nested transaction levels, innermost on top (created lazily).</summary>
        public Stack<string> SavePoints { get; set; }

        // Number of open transaction levels: 1 = the real transaction, every further level is a savepoint.
        private int _tranCount = 0;
        public int IncrementTransaction() => Interlocked.Increment(ref _tranCount);
        public int DecrementTransaction() => Interlocked.Decrement(ref _tranCount);

        /// <summary>The connection used by the commands of this session.</summary>
        public DbConnection Connection { get; set; }

        // Number of outstanding Open() calls; the connection is physically closed when it drops back to 0.
        private int _openCount = 0;
        public int IncrementConnect() => Interlocked.Increment(ref _openCount);
        public int DecrementConnect() => Interlocked.Decrement(ref _openCount);

        // Callbacks to run after the outermost transaction commits.
        private readonly List<Func<Task>> _commitCallbacks = [];

        // For every open savepoint: how many callbacks were registered before it. Rolling back to that savepoint
        // drops the callbacks registered after it, because the work they were waiting for has been undone.
        private readonly Stack<int> _callbackCountAtSavePoint = new();

        /// <summary>Queues <paramref name="func"/> to run after the outermost commit. Completes synchronously.</summary>
        public Task RunWhenCommitAsync(Func<Task> func)
        {
            _commitCallbacks.Add(func);
            return Task.CompletedTask;
        }

        /// <summary>Queues <paramref name="func"/> to run after the outermost commit.</summary>
        public void RunWhenCommit(Action func)
        {
            _commitCallbacks.Add(() =>
            {
                func();
                return Task.CompletedTask;
            });
        }

        // Removes and returns every pending callback so none can leak into a later transaction of this session.
        private List<Func<Task>> DetachCallbacks()
        {
            var callbacks = new List<Func<Task>>(_commitCallbacks);
            _commitCallbacks.Clear();
            _callbackCountAtSavePoint.Clear();
            return callbacks;
        }

        // Drops the callbacks registered after the innermost savepoint (used when rolling back to it).
        private void DiscardCallbacksSinceSavePoint()
        {
            if (_callbackCountAtSavePoint.TryPop(out var keep) && keep < _commitCallbacks.Count)
                _commitCallbacks.RemoveRange(keep, _commitCallbacks.Count - keep);
        }

        // Name for the next nested level's savepoint. It starts with a letter and stays short so it is a valid unquoted
        // identifier on engines whose identifiers must begin with a letter (e.g. PostgreSQL, Oracle) and fits SQL Server's
        // 32-character savepoint limit.
        private string NewSavePointName() => $"sp{_tranCount + 1}_{Guid.NewGuid().ToString("N")[..20]}";

        // Nested levels may not ask for a stricter isolation level than the transaction already running.
        private void EnsureIsolationLevelAllowed(IsolationLevel? isolationLevel)
        {
            if (isolationLevel != null && isolationLevel > Transaction.IsolationLevel)
                throw new InvalidOperationException($"新的事务隔离级别不能大于已开启的事务隔离级别({Transaction.IsolationLevel} => {isolationLevel})!");
        }

        private static bool IsSslAuthenticationError(Exception ex) => ex.Message?.Contains("SSL Authentication Error") == true;

        // Wraps an SSL failure with a hint about connection-string options.
        private static Exception WithSslHint(Exception ex)
            => new Exception($"db链接报错: {ex.Message}, 也许可以尝试在链接字符串中加入: [SslMode=none;AllowPublicKeyRetrieval=True;]", ex);

        #region 异步
        /// <summary>Opens the connection if this is the first user of the session, then counts the caller.</summary>
        public async Task OpenAsync()
        {
            // Count only after a successful open.
            try
            {
                if (_openCount == 0 && Connection.State == ConnectionState.Closed) await Connection.OpenAsync();
                IncrementConnect();
            }
            catch (Exception ex) when (IsSslAuthenticationError(ex))
            {
                throw WithSslHint(ex);
            }
        }

        /// <summary>Uncounts the caller and closes the connection when it was the last user.</summary>
        public async Task CloseAsync()
        {
            // Decrement first, whether or not the close below succeeds.
            var count = DecrementConnect();
            if (count == 0 && Connection.State != ConnectionState.Closed) await Connection.CloseAsync();
        }

        /// <summary>Starts the transaction (ReadCommitted by default), or adds a savepoint level when one is already open.</summary>
        public async Task OpenTransactionAsync(IsolationLevel? isolationLevel)
        {
            if (_tranCount == 0)
            {
                // Outermost level: count it only once the transaction exists.
                Transaction = await Connection.BeginTransactionAsync(isolationLevel ?? IsolationLevel.ReadCommitted);
                IncrementTransaction();
                return;
            }
            EnsureIsolationLevelAllowed(isolationLevel);
            SavePoints ??= new Stack<string>();
            var point = NewSavePointName();
            await Transaction.SaveAsync(point);
            SavePoints.Push(point);
            _callbackCountAtSavePoint.Push(_commitCallbacks.Count);
            IncrementTransaction();
        }

        /// <summary>
        /// Commits the transaction (outermost level, then runs the queued callbacks) or releases the innermost savepoint.
        /// </summary>
        /// <exception cref="InvalidOperationException">No transaction is active.</exception>
        public async Task CommitTransactionAsync()
        {
            if (_tranCount == 1)
            {
                var tran = Transaction;
                List<Func<Task>> callbacks;
                try
                {
                    await tran.CommitAsync();
                }
                finally
                {
                    // Cleanup runs whether or not the commit succeeded.
                    Transaction = null;
                    DecrementTransaction();
                    callbacks = DetachCallbacks();
                    await tran.DisposeAsync();
                }
                foreach (var callback in callbacks) await callback();
            }
            else if (_tranCount > 1)
            {
                // Nested level: nothing to commit yet, just release its savepoint.
                var point = SavePoints.Pop();
                // Callbacks registered since the savepoint now belong to the enclosing level.
                _callbackCountAtSavePoint.TryPop(out _);
                try
                {
                    await Transaction.ReleaseAsync(point);
                }
                finally
                {
                    DecrementTransaction();
                }
            }
            else
            {
                throw new InvalidOperationException("当前不在事务中, 无法提交事务!");
            }
        }

        /// <summary>
        /// Rolls back the transaction (outermost level) or to the innermost savepoint. Does nothing when no transaction is active.
        /// </summary>
        public async Task RollbackTransactionAsync()
        {
            if (_tranCount == 1)
            {
                var tran = Transaction;
                try
                {
                    await tran.RollbackAsync();
                }
                finally
                {
                    // Cleanup runs whether or not the rollback succeeded. Nothing was committed, so no callback may run.
                    Transaction = null;
                    DecrementTransaction();
                    DetachCallbacks();
                    await tran.DisposeAsync();
                }
            }
            else if (_tranCount > 1)
            {
                var point = SavePoints.Pop();
                try
                {
                    await Transaction.RollbackAsync(point);
                }
                finally
                {
                    DecrementTransaction();
                    DiscardCallbacksSinceSavePoint();
                }
            }
            // _tranCount == 0: no active transaction (e.g. a failed commit already closed it), so there is nothing to roll back.
        }
        #endregion

        #region 同步
        /// <summary>Opens the connection if this is the first user of the session, then counts the caller.</summary>
        public void Open()
        {
            // Count only after a successful open.
            try
            {
                if (_openCount == 0 && Connection.State == ConnectionState.Closed) Connection.Open();
                IncrementConnect();
            }
            catch (Exception ex) when (IsSslAuthenticationError(ex))
            {
                throw WithSslHint(ex);
            }
        }

        /// <summary>Uncounts the caller and closes the connection when it was the last user.</summary>
        public void Close()
        {
            // Decrement first, whether or not the close below succeeds.
            var count = DecrementConnect();
            if (count == 0 && Connection.State != ConnectionState.Closed) Connection.Close();
        }

        /// <summary>Starts the transaction (ReadCommitted by default), or adds a savepoint level when one is already open.</summary>
        public void OpenTransaction(IsolationLevel? isolationLevel)
        {
            if (_tranCount == 0)
            {
                // Outermost level: count it only once the transaction exists.
                Transaction = Connection.BeginTransaction(isolationLevel ?? IsolationLevel.ReadCommitted);
                IncrementTransaction();
                return;
            }
            EnsureIsolationLevelAllowed(isolationLevel);
            SavePoints ??= new Stack<string>();
            var point = NewSavePointName();
            Transaction.Save(point);
            SavePoints.Push(point);
            _callbackCountAtSavePoint.Push(_commitCallbacks.Count);
            IncrementTransaction();
        }

        /// <summary>
        /// Commits the transaction (outermost level, then runs the queued callbacks) or releases the innermost savepoint.
        /// </summary>
        /// <exception cref="InvalidOperationException">No transaction is active.</exception>
        public void CommitTransaction()
        {
            if (_tranCount == 1)
            {
                var tran = Transaction;
                List<Func<Task>> callbacks;
                try
                {
                    tran.Commit();
                }
                finally
                {
                    // Cleanup runs whether or not the commit succeeded.
                    Transaction = null;
                    DecrementTransaction();
                    callbacks = DetachCallbacks();
                    tran.Dispose();
                }
                // GetAwaiter().GetResult() rethrows a callback's own exception rather than an AggregateException.
                foreach (var callback in callbacks) callback().GetAwaiter().GetResult();
            }
            else if (_tranCount > 1)
            {
                // Nested level: nothing to commit yet, just release its savepoint.
                var point = SavePoints.Pop();
                // Callbacks registered since the savepoint now belong to the enclosing level.
                _callbackCountAtSavePoint.TryPop(out _);
                try
                {
                    Transaction.Release(point);
                }
                finally
                {
                    DecrementTransaction();
                }
            }
            else
            {
                throw new InvalidOperationException("当前不在事务中, 无法提交事务!");
            }
        }

        /// <summary>
        /// Rolls back the transaction (outermost level) or to the innermost savepoint. Does nothing when no transaction is active.
        /// </summary>
        public void RollbackTransaction()
        {
            if (_tranCount == 1)
            {
                var tran = Transaction;
                try
                {
                    tran.Rollback();
                }
                finally
                {
                    // Cleanup runs whether or not the rollback succeeded. Nothing was committed, so no callback may run.
                    Transaction = null;
                    DecrementTransaction();
                    DetachCallbacks();
                    tran.Dispose();
                }
            }
            else if (_tranCount > 1)
            {
                var point = SavePoints.Pop();
                try
                {
                    Transaction.Rollback(point);
                }
                finally
                {
                    DecrementTransaction();
                    DiscardCallbacksSinceSavePoint();
                }
            }
            // _tranCount == 0: no active transaction (e.g. a failed commit already closed it), so there is nothing to roll back.
        }
        #endregion
    }
}
